{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读取和存储\n",
    "\n",
    "到目前为止，我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中，我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下，我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。\n",
    "\n",
    "\n",
    "## 读写`Tensor`\n",
    "\n",
    "我们可以直接使用`save`函数和`load`函数分别存储和读取`Tensor`。下面的例子创建了`Tensor`变量`x`，并将其存在文件名同为`x`的文件里。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "x = torch.ones(3)\n",
    "torch.save(x, 'x')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们将数据从存储的文件读回内存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1.])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2 = torch.load('x')\n",
    "x2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们还可以存储一列`Tensor`并读回内存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1., 1., 1.]), tensor([0., 0., 0., 0.]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.zeros(4)\n",
    "torch.save([x, y], 'xy')\n",
    "x2, y2 = torch.load('xy')\n",
    "(x2, y2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们甚至可以存储并读取一个从字符串映射到`Tensor`的字典。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mydict = {'x': x, 'y': y}\n",
    "torch.save(mydict, 'mydict')\n",
    "mydict2 = torch.load('mydict')\n",
    "mydict2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读写模型的参数\n",
    "\n",
    "除`Tensor`以外，我们还可以读写模型的参数。我们可以使用`save`方法来保存模型的`state_dict`，`Module`类提供了`load_state_dict`函数来读取模型参数。为了演示方便，我们先创建一个多层感知机，并将其初始化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(MLP, self).__init__(**kwargs)\n",
    "        self.hidden = nn.Linear(20, 256)\n",
    "        self.activation = nn.ReLU()\n",
    "        self.output = nn.Linear(256, 10)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.output(self.activation(self.hidden(x)))\n",
    "    \n",
    "net = MLP()\n",
    "X = torch.rand(2, 20)\n",
    "Y = net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面把该模型的参数存成文件，文件名为mlp.params。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = 'mlp.params'\n",
    "torch.save(net.state_dict(), filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们再实例化一次定义好的多层感知机。与随机初始化模型参数不同，我们在这里直接读取保存在文件里的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net2 = MLP()\n",
    "net2.load_state_dict(torch.load(filename))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为这两个实例都有同样的模型参数，那么对同一个输入`X`的计算结果将会是一样的。我们来验证一下。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True, True, True, True, True, True],\n",
       "        [True, True, True, True, True, True, True, True, True, True]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y2 = net2(X)\n",
    "Y2 == Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 通过`save`函数和`load`函数可以很方便地读写`Tensor`。\n",
    "* 通过`load_state_dict`函数可以很方便地读取模型的参数。\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 即使无须把训练好的模型部署到不同的设备，存储模型参数在实际中还有哪些好处？\n",
    "\n",
    "\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/1255)\n",
    "\n",
    "![](../img/qr_read-write.svg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
